// Copyright Epic Games, Inc. All Rights Reserved.

#include "winsocktransport.h"

#if ZEN_WITH_PLUGINS

#	include <zencore/logging.h>
#	include <zencore/scopeguard.h>
#	include <zencore/workthreadpool.h>

#	if ZEN_PLATFORM_WINDOWS
#		include <zencore/windows.h>
ZEN_THIRD_PARTY_INCLUDES_START
#		include <WinSock2.h>
#		include <WS2tcpip.h>
ZEN_THIRD_PARTY_INCLUDES_END
#	endif

#	include <thread>

namespace zen {

// Transport connection backed by a single client SOCKET.
//
// HandleConnection() pumps bytes received on the socket into the
// TransportServerConnection handler supplied via Initialize(); the
// TransportConnection overrides below write to, shut down and close that
// same socket.
struct SocketTransportConnection : public TransportConnection
{
public:
	SocketTransportConnection();
	~SocketTransportConnection();

	// Attaches the handler and the client socket; asserts that no handler has been attached yet
	void Initialize(TransportServerConnection* ServerConnection, SOCKET ClientSocket);
	// Receive loop: forwards incoming bytes to the handler until recv reports
	// close/error or the socket has been closed
	void HandleConnection();

	// TransportConnection

	// Sends DataSize bytes from Buffer; returns the byte count sent, or the negative send() result on error
	virtual int64_t		WriteBytes(const void* Buffer, size_t DataSize) override;
	// Shuts down the receive and/or transmit direction of the socket (no-op when both flags are false)
	virtual void		Shutdown(bool Receive, bool Transmit) override;
	// Shuts down and closes the socket; subsequent calls are no-ops
	virtual void		CloseConnection() override;
	// Currently always returns nullptr
	virtual const char* GetDebugName() override;

private:
	// Handler which receives the bytes read from the socket
	Ref<TransportServerConnection> m_ConnectionHandler;
	// Client socket; reset to 0 by CloseConnection()
	SOCKET						   m_ClientSocket{};
	// Set by CloseConnection() so that the socket is only closed once
	bool						   m_IsTerminated = false;
};

// Nothing to set up here: all members carry in-class initializers and the
// socket/handler are attached later through Initialize().
SocketTransportConnection::SocketTransportConnection() = default;

// Releases the client socket if it is still open.
//
// HandleConnection() closes the socket itself on its regular exit paths, but
// if the receive loop is left through an exception the connection is deleted
// while the socket is still open. Closing it here prevents the handle from
// leaking in that case.
SocketTransportConnection::~SocketTransportConnection()
{
	// m_ClientSocket is 0 when Initialize() was never called or the socket has
	// already been closed; m_IsTerminated guards against closing twice
	if (!m_IsTerminated && m_ClientSocket)
	{
		m_IsTerminated = true;

		shutdown(m_ClientSocket, SD_BOTH);
		closesocket(m_ClientSocket);
		m_ClientSocket = 0;
	}
}

// Binds this connection to its handler and client socket.
//
// ServerConnection receives every chunk of data read by HandleConnection();
// ClientSocket is the socket used for all subsequent I/O.
void
SocketTransportConnection::Initialize(TransportServerConnection* ServerConnection, SOCKET ClientSocket)
{
	// A connection must not be initialized more than once
	ZEN_ASSERT(!m_ConnectionHandler);

	m_ClientSocket		= ClientSocket;
	m_ConnectionHandler = ServerConnection;
}

// Blocking receive loop for this connection.
//
// Reads from the client socket in chunks of up to 64KiB and hands every chunk
// to the connection handler. The loop ends (after closing the connection)
// when recv reports a closed connection or an error, or as soon as the socket
// has been reset to 0 by CloseConnection().
void
SocketTransportConnection::HandleConnection()
{
	ZEN_ASSERT(m_ConnectionHandler);

	const int	   InputBufferSize = 64 * 1024;
	uint8_t* const InputBuffer	   = new uint8_t[InputBufferSize];
	auto		   _			   = MakeGuard([&] { delete[] InputBuffer; });

	for (;;)
	{
		const int RecvBytes = recv(m_ClientSocket, (char*)InputBuffer, InputBufferSize, /* flags */ 0);

		// Zero means the connection was closed, a negative value signals an
		// error - either way we are done with this connection
		if (RecvBytes <= 0)
		{
			CloseConnection();
			return;
		}

		m_ConnectionHandler->OnBytesRead(InputBuffer, RecvBytes);

		// The socket may have been closed while the data was being processed
		if (!m_ClientSocket)
		{
			return;
		}
	}
}

void
SocketTransportConnection::CloseConnection()
{
	if (m_IsTerminated)
	{
		return;
	}

	ZEN_ASSERT(m_ClientSocket);
	m_IsTerminated = true;

	shutdown(m_ClientSocket, SD_BOTH);	// We won't be sending or receiving any more data

	closesocket(m_ClientSocket);
	m_ClientSocket = 0;
}

// Returns a debug name for this connection.
//
// No name is provided at the moment, so this always returns nullptr.
// NOTE(review): callers presumably need to tolerate a null name - confirm
const char*
SocketTransportConnection::GetDebugName()
{
	return nullptr;
}

// Writes DataSize bytes from Buffer to the client socket.
//
// The data is handed to send() in blocks of at most 128KiB until everything
// has been accepted. Returns the total number of bytes sent, or the negative
// send() result as soon as a send fails.
int64_t
SocketTransportConnection::WriteBytes(const void* Buffer, size_t DataSize)
{
	const size_t MaxBlockSize = 128 * 1024;

	const char* Cursor	  = static_cast<const char*>(Buffer);
	size_t		Remaining = DataSize;
	int64_t		Total	  = 0;

	while (Remaining != 0)
	{
		const int BlockSize = static_cast<int>((Remaining > MaxBlockSize) ? MaxBlockSize : Remaining);
		const int SentBytes = send(m_ClientSocket, Cursor, BlockSize, /* flags */ 0);

		if (SentBytes < 0)
		{
			// Error - report the failure to the caller as-is
			return SentBytes;
		}

		// send() may accept fewer bytes than requested, so advance by what
		// was actually written
		Cursor += SentBytes;
		Remaining -= SentBytes;
		Total += SentBytes;
	}

	return Total;
}

// Shuts down the requested direction(s) of the client socket.
//
// Receive and Transmit select which half of the connection to shut down; if
// neither is set the call does nothing.
void
SocketTransportConnection::Shutdown(bool Receive, bool Transmit)
{
	if (!Receive && !Transmit)
	{
		return;
	}

	// Map the two flags onto the matching shutdown() mode
	const int How = Receive ? (Transmit ? SD_BOTH : SD_RECEIVE) : SD_SEND;

	shutdown(m_ClientSocket, How);
}

//////////////////////////////////////////////////////////////////////////

// Transport plugin which accepts TCP connections on a listening socket and
// services each accepted connection on a worker thread pool.
//
// RefCounted is a private base (class default); AddRef()/Release() forward to it.
class SocketTransportPluginImpl : public TransportPlugin, RefCounted
{
public:
	SocketTransportPluginImpl();
	~SocketTransportPluginImpl();

	virtual uint32_t	AddRef() const override;
	virtual uint32_t	Release() const override;
	// Recognized options: "port" and "threads"; anything else is ignored
	virtual void		Configure(const char* OptionTag, const char* OptionValue) override;
	// Creates the worker pool and listen socket and starts the accept thread
	virtual void		Initialize(TransportServer* ServerInterface) override;
	// Stops the accept thread and closes the listen socket
	virtual void		Shutdown() override;
	virtual const char* GetDebugName() override;
	virtual bool		IsAvailable() override;

private:
	// Server interface used to create a connection handler for each accepted client
	TransportServer* m_ServerInterface = nullptr;
	// Port to listen on (overridden by the "port" option)
	uint16_t		 m_BasePort		   = 8558;
	// Worker thread pool size (overridden by the "threads" option)
	int				 m_ThreadCount	   = 8;
	// Cleared when WSAStartup fails in the constructor
	bool			 m_IsOk			   = true;

	// Listening socket; reset to 0 by Shutdown()
	SOCKET							  m_ListenSocket{};
	// Thread running the accept loop started by Initialize()
	std::thread						  m_AcceptThread;
	// Set while the accept loop should keep going; cleared by Shutdown()
	std::atomic_flag				  m_KeepRunning;
	// Pool on which the per-connection receive loops are scheduled
	std::unique_ptr<WorkerThreadPool> m_WorkerThreadpool;
};

// Initializes WinSock (version 2.2) for this plugin instance.
//
// If initialization fails, m_IsOk is cleared so the destructor knows not to
// call WSACleanup(). WSACleanup() must only be paired with a *successful*
// WSAStartup() call, so it is deliberately not called on the failure path -
// doing so could release a WinSock reference held by other code in the
// process.
SocketTransportPluginImpl::SocketTransportPluginImpl()
{
#	if ZEN_PLATFORM_WINDOWS
	WSADATA wsaData;
	if (int Result = WSAStartup(MAKEWORD(2, 2), &wsaData); Result != 0)
	{
		// WSAStartup returns the error code directly (WSAGetLastError is not
		// usable when startup has failed)
		ZEN_ERROR("WSAStartup failed in socket transport plugin: {}", Result);

		m_IsOk = false;
	}
#	endif
}

// Stops the plugin and releases the WinSock reference taken in the
// constructor (only if that initialization succeeded).
SocketTransportPluginImpl::~SocketTransportPluginImpl()
{
	// Make sure the accept thread is gone before tearing down WinSock
	Shutdown();

#	if ZEN_PLATFORM_WINDOWS
	if (!m_IsOk)
	{
		return;
	}

	WSACleanup();
#	endif
}

// TransportPlugin reference counting: forwards to the RefCounted base and
// returns whatever RefCounted::AddRef() returns.
uint32_t
SocketTransportPluginImpl::AddRef() const
{
	return RefCounted::AddRef();
}

// TransportPlugin reference counting: forwards to the RefCounted base and
// returns whatever RefCounted::Release() returns.
// NOTE(review): presumably the object is destroyed when the count drops to zero - confirm in RefCounted
uint32_t
SocketTransportPluginImpl::Release() const
{
	return RefCounted::Release();
}

void
SocketTransportPluginImpl::Configure(const char* OptionTag, const char* OptionValue)
{
	using namespace std::literals;

	if (OptionTag == "port"sv)
	{
		if (auto PortNum = ParseInt<uint16_t>(OptionValue))
		{
			m_BasePort = *PortNum;
		}
	}
	else if (OptionTag == "threads"sv)
	{
		if (auto ThreadCount = ParseInt<int>(OptionValue))
		{
			m_ThreadCount = *ThreadCount;
		}
	}
	else
	{
		// Unknown configuration option
	}
}

bool
SocketTransportPluginImpl::IsAvailable()
{
	return true;
}

// Starts the transport.
//
// Creates the worker thread pool, opens an IPv6 TCP listening socket on the
// configured port and spawns the accept thread. ServerInterface is stored and
// used by the accept thread to create a connection handler for every accepted
// client; each connection's receive loop is then scheduled on the worker pool.
//
// If creating, binding or listening on the socket fails, the error is logged,
// the listen socket is released and the function returns without starting the
// accept thread.
void
SocketTransportPluginImpl::Initialize(TransportServer* ServerInterface)
{
	m_ServerInterface  = ServerInterface;
	m_WorkerThreadpool = std::make_unique<WorkerThreadPool>(m_ThreadCount, "http_conn");

	m_ListenSocket = socket(AF_INET6, SOCK_STREAM, 0);

	if (m_ListenSocket == INVALID_SOCKET)
	{
		ZEN_ERROR("socket creation failed in HTTP plugin server init: {}", WSAGetLastError());

		// Go back to the "no socket" value which Shutdown() checks for
		m_ListenSocket = 0;

		return;
	}

	// Used on the failure paths below so that no half-configured socket is left behind
	const auto ReleaseListenSocket = [this] {
		closesocket(m_ListenSocket);
		m_ListenSocket = 0;
	};

	sockaddr_in6 Server{};
	Server.sin6_family = AF_INET6;
	Server.sin6_port   = htons(m_BasePort);
	Server.sin6_addr   = in6addr_any;

	if (int Result = bind(m_ListenSocket, reinterpret_cast<const sockaddr*>(&Server), sizeof(Server)); Result == SOCKET_ERROR)
	{
		// Note: the error code is fetched before the socket is closed
		ZEN_ERROR("bind call failed in HTTP plugin server init: {}", WSAGetLastError());

		ReleaseListenSocket();

		return;
	}

	// The second argument is the backlog (length of the pending connection
	// queue), not an address family; SOMAXCONN lets the socket provider pick
	// a reasonable maximum
	if (int Result = listen(m_ListenSocket, SOMAXCONN); Result == SOCKET_ERROR)
	{
		ZEN_ERROR("listen call failed in HTTP plugin server init: {}", WSAGetLastError());

		ReleaseListenSocket();

		return;
	}

	m_KeepRunning.test_and_set();

	m_AcceptThread = std::thread([this] {
		SetCurrentThreadName("http_plugin_acceptor");

		ZEN_INFO("HTTP plugin server waiting for connections");

		do
		{
			const SOCKET ClientSocket = accept(m_ListenSocket, nullptr, nullptr);

			if (ClientSocket == INVALID_SOCKET)
			{
				// accept() fails once Shutdown() has closed the listen socket;
				// the loop condition below decides whether to keep going
				continue;
			}

			// Disable Nagle's algorithm on the accepted socket
			int Flag = 1;
			setsockopt(ClientSocket, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<const char*>(&Flag), sizeof(Flag));

			// Handle new connection
			SocketTransportConnection* Connection = new SocketTransportConnection();
			TransportServerConnection* ConnectionInterface{m_ServerInterface->CreateConnectionHandler(Connection)};
			Connection->Initialize(ConnectionInterface, ClientSocket);

			m_WorkerThreadpool->ScheduleWork([Connection] {
				// The worker owns the connection: make sure it is deleted
				// regardless of how the receive loop is left
				auto _ = MakeGuard([Connection] { delete Connection; });

				try
				{
					Connection->HandleConnection();
				}
				catch (const std::exception& Ex)
				{
					ZEN_WARN("exception caught in connection loop: {}", Ex.what());
				}
			});
		} while (!IsApplicationExitRequested() && m_KeepRunning.test());

		ZEN_INFO("HTTP plugin server accept thread exit");
	});
}

void
SocketTransportPluginImpl::Shutdown()
{
	// TODO: all pending/ongoing work should be drained here as well

	m_KeepRunning.clear();

	if (m_ListenSocket)
	{
		closesocket(m_ListenSocket);
		m_ListenSocket = 0;
	}

	if (m_AcceptThread.joinable())
	{
		m_AcceptThread.join();
	}
}

// Returns a debug name for this plugin.
//
// No name is provided at the moment, so this always returns nullptr.
// NOTE(review): callers presumably need to tolerate a null name - confirm
const char*
SocketTransportPluginImpl::GetDebugName()
{
	return nullptr;
}

//////////////////////////////////////////////////////////////////////////

// Factory for the socket transport plugin.
//
// Returns a newly heap-allocated SocketTransportPluginImpl.
// NOTE(review): the caller presumably manages its lifetime through
// AddRef()/Release() - confirm against the plugin host
TransportPlugin*
CreateSocketTransportPlugin()
{
	return new SocketTransportPluginImpl;
}

}  // namespace zen

#endif
